{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 6: A Journey Into Sound\n",
    "\n",
    "(note: uses PyTorch 1.6 and torchaudio 0.6.0)\n",
    "\n",
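    "Download and extract the ESC-50 files from https://github.com/karolpiczak/ESC-50#download, then arrange the `.wav` clips into `train`, `valid` and `test` sub-directories of an `esc50` directory inside the current working directory; the dataset cells below read from those paths."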
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display as display\n",
    "import librosa\n",
    "import librosa.display\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torchaudio\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import models, transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    \"\"\"Train `model` for `epochs` epochs, printing losses and validation accuracy per epoch.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; it is not moved here, only the batches are.\n",
    "        optimizer: optimizer stepping the model's parameters.\n",
    "        loss_fn: callable `loss_fn(output, targets)`; assumed to return the batch mean.\n",
    "        train_loader: loader yielding `(inputs, targets)` training batches.\n",
    "        val_loader: loader yielding `(inputs, targets)` validation batches.\n",
    "        epochs: number of passes over `train_loader`.\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "    \"\"\"\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        training_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "        model.train()\n",
    "        for inputs, targets in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = model(inputs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Weight the per-batch loss by batch size so the epoch figure is a per-sample average.\n",
    "            training_loss += loss.item() * inputs.size(0)\n",
    "        training_loss /= len(train_loader.dataset)\n",
    "\n",
    "        model.eval()\n",
    "        num_correct = 0\n",
    "        num_examples = 0\n",
    "        # Validation needs no gradients; disabling autograd avoids building graphs that are never used.\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in val_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                targets = targets.to(device)\n",
    "                output = model(inputs)\n",
    "                loss = loss_fn(output, targets)\n",
    "                valid_loss += loss.item() * inputs.size(0)\n",
    "                # argmax over the class dimension; softmax would not change which class wins.\n",
    "                predictions = output.argmax(dim=1)\n",
    "                num_correct += (predictions == targets).sum().item()\n",
    "                num_examples += targets.shape[0]\n",
    "        valid_loss /= len(val_loader.dataset)\n",
    "        # Guard against an empty validation loader instead of dividing by zero.\n",
    "        accuracy = num_correct / num_examples if num_examples else 0.0\n",
    "\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(\n",
    "            epoch, training_loss, valid_loss, accuracy))\n",
    "        \n",
    "def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-8, final_value=10.0, device=\"cpu\"):\n",
    "    \"\"\"Run a learning-rate range test over a single pass of `train_loader`.\n",
    "\n",
    "    The learning rate of the optimizer's first parameter group starts at\n",
    "    `init_value` and is multiplied by a constant factor after every batch, so\n",
    "    that it reaches `final_value` on the last batch. The sweep stops early once\n",
    "    the loss becomes NaN or exceeds four times the best loss seen so far.\n",
    "\n",
    "    The model weights and optimizer state are updated during the sweep; this\n",
    "    function does not restore them.\n",
    "\n",
    "    Returns:\n",
    "        (lrs, losses): learning rates tried and the loss recorded at each one.\n",
    "        When more than 20 points were recorded, the first 10 and last 5 are dropped.\n",
    "    \"\"\"\n",
    "    # A one-batch loader leaves no interval to spread the sweep over; avoid dividing by zero.\n",
    "    number_in_epoch = max(len(train_loader) - 1, 1)\n",
    "    update_step = (final_value / init_value) ** (1 / number_in_epoch)\n",
    "    lr = init_value\n",
    "    optimizer.param_groups[0][\"lr\"] = lr\n",
    "    best_loss = 0.0\n",
    "    losses = []\n",
    "    log_lrs = []\n",
    "    for batch_num, (inputs, targets) in enumerate(train_loader, start=1):\n",
    "        inputs = inputs.to(device)\n",
    "        targets = targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = loss_fn(outputs, targets)\n",
    "        # Compare plain floats so best_loss never holds a tensor (and its autograd graph).\n",
    "        loss_value = loss.item()\n",
    "\n",
    "        # Stop if the loss explodes; `x != x` is true only for NaN.\n",
    "        if loss_value != loss_value or (batch_num > 1 and loss_value > 4 * best_loss):\n",
    "            break\n",
    "\n",
    "        # Record the best loss\n",
    "        if batch_num == 1 or loss_value < best_loss:\n",
    "            best_loss = loss_value\n",
    "\n",
    "        # Store the values\n",
    "        losses.append(loss_value)\n",
    "        log_lrs.append(lr)\n",
    "\n",
    "        # Do the backward pass and optimize\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Update the lr for the next step\n",
    "        lr *= update_step\n",
    "        optimizer.param_groups[0][\"lr\"] = lr\n",
    "\n",
    "    # Trim the noisy head and tail of the curve when there are enough points.\n",
    "    if len(log_lrs) > 20:\n",
    "        return log_lrs[10:-5], losses[10:-5]\n",
    "    return log_lrs, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ESC-50 Dataset & DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ESC50(Dataset):\n",
    "    \"\"\"ESC-50 clips from one directory, served as (waveform, class id) pairs.\n",
    "\n",
    "    The class id is the last dash-separated field of the file name,\n",
    "    e.g. ``1-100032-A-0.wav`` -> ``0``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,path):\n",
    "        # Collect (filename, label string) for every .wav directly inside `path`.\n",
    "        self.items = []\n",
    "        for wav_path in Path(path).glob('*.wav'):\n",
    "            class_id = wav_path.name.split(\"-\")[-1].replace(\".wav\",\"\")\n",
    "            self.items.append((str(wav_path), class_id))\n",
    "        self.length = len(self.items)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        wav_file, class_id = self.items[index]\n",
    "        # torchaudio.load also returns the sample rate, which is not used here.\n",
    "        waveform, _ = torchaudio.load(wav_file)\n",
    "        return (waveform, int(class_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the CPU when no GPU is visible so the notebook still runs end to end.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "bs = 64\n",
    "PATH_TO_ESC50 = Path.cwd() / 'esc50'\n",
    "\n",
    "train_esc50 = ESC50(PATH_TO_ESC50 / \"train\")\n",
    "valid_esc50 = ESC50(PATH_TO_ESC50 / \"valid\")\n",
    "test_esc50  = ESC50(PATH_TO_ESC50 / \"test\")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_esc50, batch_size=bs, shuffle=True)\n",
    "# Evaluation splits are only scored, never optimised over, so their order is left as-is.\n",
    "valid_loader = torch.utils.data.DataLoader(valid_esc50, batch_size=bs, shuffle=False)\n",
    "test_loader  = torch.utils.data.DataLoader(test_esc50, batch_size=bs, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## M5-based CNN AudioNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AudioNet(nn.Module):\n",
    "    \"\"\"1-D CNN: four conv / batch-norm / ReLU / max-pool stages and a 50-way linear head.\n",
    "\n",
    "    The input is folded into 100 channels of 2205 values each, so its element\n",
    "    count must be a multiple of 100 * 2205.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Stage 1 is the only strided convolution.\n",
    "        self.conv1 = nn.Conv1d(in_channels=100, out_channels=128, kernel_size=5, stride=4)\n",
    "        self.bn1 = nn.BatchNorm1d(num_features=128)\n",
    "        self.pool1 = nn.MaxPool1d(kernel_size=4)\n",
    "        # Stage 2\n",
    "        self.conv2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3)\n",
    "        self.bn2 = nn.BatchNorm1d(num_features=128)\n",
    "        self.pool2 = nn.MaxPool1d(kernel_size=4)\n",
    "        # Stage 3\n",
    "        self.conv3 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3)\n",
    "        self.bn3 = nn.BatchNorm1d(num_features=256)\n",
    "        self.pool3 = nn.MaxPool1d(kernel_size=4)\n",
    "        # Stage 4\n",
    "        self.conv4 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3)\n",
    "        self.bn4 = nn.BatchNorm1d(num_features=512)\n",
    "        self.pool4 = nn.MaxPool1d(kernel_size=4)\n",
    "        # Classifier head\n",
    "        self.fc1 = nn.Linear(in_features=512, out_features=50)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Fold the input into (batch, 100, 2205).\n",
    "        x = x.unsqueeze(-1).view(-1, 100, 2205)\n",
    "        stages = (\n",
    "            (self.conv1, self.bn1, self.pool1),\n",
    "            (self.conv2, self.bn2, self.pool2),\n",
    "            (self.conv3, self.bn3, self.pool3),\n",
    "            (self.conv4, self.bn4, self.pool4),\n",
    "        )\n",
    "        for conv, norm, pool in stages:\n",
    "            x = pool(F.relu(norm(conv(x))))\n",
    "        # Drop the trailing length dimension before the linear head.\n",
    "        return self.fc1(x.squeeze(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the raw-waveform model and move its parameters to the selected device.\n",
    "audionet = AudioNet()\n",
    "audionet.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Find learning rate & train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot the untrained weights: the sweep below updates the model, and the\n",
    "# snapshot is reloaded before the real training run.\n",
    "torch.save(audionet.state_dict(), \"audionet.pth\")\n",
    "optimizer = optim.Adam(audionet.parameters(), lr=0.001)\n",
    "logs,losses = find_lr(audionet, nn.CrossEntropyLoss(), optimizer, train_loader, device=device)\n",
    "\n",
    "# The learning rate grows geometrically, so the curve is only readable on a log x-axis.\n",
    "plt.plot(logs, losses)\n",
    "plt.xscale(\"log\")\n",
    "plt.xlabel(\"learning rate\")\n",
    "plt.ylabel(\"loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning rate used for the AudioNet run (presumably read off the sweep plot above).\n",
    "lr = 1e-5\n",
    "# Restore the weights saved before the sweep, discarding the updates it made.\n",
    "audionet.load_state_dict(torch.load(\"audionet.pth\"))\n",
    "# Fresh optimizer so no Adam state from the sweep carries over.\n",
    "optimizer = optim.Adam(audionet.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train AudioNet on the raw waveforms for 20 epochs; train() prints losses and accuracy per epoch.\n",
    "train(audionet, optimizer, torch.nn.CrossEntropyLoss(),train_loader, valid_loader, epochs=20, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Spectrograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one training clip at its native sample rate (sr=None disables resampling).\n",
    "# The path is built from PATH_TO_ESC50 so it matches the directory the datasets above read from.\n",
    "sample_data, sr = librosa.load(PATH_TO_ESC50 / \"train\" / \"1-100032-A-0.wav\", sr=None)\n",
    "# Pass the audio by keyword: recent librosa releases no longer accept it positionally.\n",
    "spectrogram = librosa.feature.melspectrogram(y=sample_data, sr=sr)\n",
    "# Convert power to decibels relative to the loudest bin.\n",
    "log_spectrogram = librosa.power_to_db(spectrogram, ref=np.max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_spectrograms(path, dpi=50):\n",
    "    \"\"\"Render every .wav in `path` as a log-mel spectrogram PNG.\n",
    "\n",
    "    Each image is written next to its source clip as\n",
    "    ``<directory name><dpi>_<clip name>.png`` (e.g. ``train50_1-100032-A-0.wav.png``),\n",
    "    which is the pattern PrecomputedESC50 globs for inside that directory.\n",
    "\n",
    "    Args:\n",
    "        path: directory containing the .wav clips.\n",
    "        dpi: resolution passed to savefig; also embedded in the output file name.\n",
    "    \"\"\"\n",
    "    for filename in Path(path).glob('*.wav'):\n",
    "        audio_tensor, sr = librosa.load(filename, sr=None)\n",
    "        spectrogram = librosa.feature.melspectrogram(y=audio_tensor, sr=sr)\n",
    "        log_spectrogram = librosa.power_to_db(spectrogram, ref=np.max)\n",
    "        # One fresh figure per clip: drawing on the shared current figure would\n",
    "        # stack every spectrogram onto the same axes.\n",
    "        fig, ax = plt.subplots()\n",
    "        librosa.display.specshow(log_spectrogram, sr=sr, x_axis='time', y_axis='mel', ax=ax)\n",
    "        out_name = \"{}{}_{}.png\".format(filename.parent.name, dpi, filename.name)\n",
    "        fig.savefig(filename.parent / out_name, dpi=dpi)\n",
    "        # Release the figure so memory does not grow with the number of clips.\n",
    "        plt.close(fig)\n",
    "\n",
    "# Directories whose clips are rendered to spectrogram images.\n",
    "PATH_ESC50_TRAIN = PATH_TO_ESC50 / \"train\"\n",
    "PATH_ESC50_VALID = PATH_TO_ESC50 / \"valid\"\n",
    "\n",
    "# Render the training and validation clips at the default dpi (50).\n",
    "precompute_spectrograms(PATH_ESC50_TRAIN)\n",
    "precompute_spectrograms(PATH_ESC50_VALID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrecomputedESC50(Dataset):\n",
    "    \"\"\"Spectrogram images of ESC-50 clips, read from PNG files inside `path`.\n",
    "\n",
    "    Files must be named ``<directory name><dpi>*.wav.png``; the label is the\n",
    "    integer after the final dash of the file name.\n",
    "\n",
    "    Args:\n",
    "        path: directory holding the PNG files (str or Path).\n",
    "        dpi: dpi value embedded in the file names to load.\n",
    "        img_transforms: callable applied to each RGB image; defaults to ToTensor.\n",
    "    \"\"\"\n",
    "    def __init__(self,path,dpi=50, img_transforms=None):\n",
    "        # Normalise once so both str and Path arguments work (a str has no .name).\n",
    "        path = Path(path)\n",
    "        files = path.glob('{}{}*.wav.png'.format(path.name, dpi))\n",
    "        self.items = [(f,int(f.name.split(\"-\")[-1].replace(\".wav.png\",\"\"))) for f in files]\n",
    "        self.length = len(self.items)\n",
    "        if img_transforms is None:\n",
    "            self.img_transforms = transforms.Compose([transforms.ToTensor()])\n",
    "        else:\n",
    "            self.img_transforms = img_transforms\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        filename, label = self.items[index]\n",
    "        # Force three channels regardless of how the PNG was saved.\n",
    "        img = Image.open(filename).convert('RGB')\n",
    "        return (self.img_transforms(img), label)\n",
    "            \n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pretrained ResNet50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from torchvision's ResNet-50 with pretrained weights.\n",
    "spec_resnet = models.resnet50(pretrained=True)\n",
    "\n",
    "# Freeze every pretrained parameter; only layers created after this point stay trainable.\n",
    "for param in spec_resnet.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# Replace the classifier with a new head that ends in 50 outputs.\n",
    "spec_resnet.fc = nn.Sequential(nn.Linear(spec_resnet.fc.in_features,500),\n",
    "                               nn.ReLU(),\n",
    "                               nn.Dropout(), nn.Linear(500,50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits use identical preprocessing, so the pipeline is defined once.\n",
    "# The mean/std are the normalisation constants torchvision documents for its pretrained models.\n",
    "spec_transforms = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "esc50pre_train = PrecomputedESC50(PATH_ESC50_TRAIN, img_transforms=spec_transforms)\n",
    "esc50pre_valid = PrecomputedESC50(PATH_ESC50_VALID, img_transforms=spec_transforms)\n",
    "\n",
    "esc50_train_loader = torch.utils.data.DataLoader(esc50pre_train, bs, shuffle=True)\n",
    "# Validation data is only scored, so shuffling it serves no purpose.\n",
    "esc50_val_loader = torch.utils.data.DataLoader(esc50pre_valid, bs, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the ResNet to the training device and snapshot its weights before the sweep changes them.\n",
    "spec_resnet.to(device) \n",
    "torch.save(spec_resnet.state_dict(), \"spec_resnet.pth\")\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "# NOTE(review): `lr` is not assigned in this section; it is whatever an earlier cell left behind\n",
    "# (1e-5 in the AudioNet section). find_lr overwrites the first param group's lr with its own\n",
    "# init_value, so this value does not steer the sweep.\n",
    "optimizer = optim.Adam(spec_resnet.parameters(), lr=lr)\n",
    "logs,losses = find_lr(spec_resnet, loss_fn, optimizer, esc50_train_loader, device=device)\n",
    "plt.plot(logs, losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the weights saved before the learning-rate sweep.\n",
    "spec_resnet.load_state_dict(torch.load(\"spec_resnet.pth\"))\n",
    "# One parameter group per ResNet stage; groups without their own 'lr' use the default (1e-2).\n",
    "# NOTE(review): relu, maxpool and avgpool are presumably parameter-free, making those groups empty - confirm.\n",
    "# NOTE(review): fc is the only trainable part during the first train() call below (the backbone is\n",
    "# frozen until requires_grad is switched back on), yet it has the smallest lr (1e-8) - confirm this is intended.\n",
    "optimizer = optim.Adam([\n",
    "                        {'params': spec_resnet.conv1.parameters()},\n",
    "                        {'params': spec_resnet.bn1.parameters()},\n",
    "                        {'params': spec_resnet.relu.parameters()},\n",
    "                        {'params': spec_resnet.maxpool.parameters()},\n",
    "                        {'params': spec_resnet.layer1.parameters(), 'lr': 1e-4},\n",
    "                        {'params': spec_resnet.layer2.parameters(), 'lr': 1e-4},\n",
    "                        {'params': spec_resnet.layer3.parameters(), 'lr': 1e-4},\n",
    "                        {'params': spec_resnet.layer4.parameters(), 'lr': 1e-4},\n",
    "                        {'params': spec_resnet.avgpool.parameters(), 'lr': 1e-4},\n",
    "                        {'params': spec_resnet.fc.parameters(), 'lr': 1e-8}\n",
    "                        ], lr=1e-2)\n",
    "\n",
    "# Phase 1: five epochs with whatever parameters currently have requires_grad=True.\n",
    "train(spec_resnet, optimizer, nn.CrossEntropyLoss(), esc50_train_loader, esc50_val_loader, epochs=5, device=device)\n",
    "\n",
    "# Phase 2: make every parameter trainable and continue with the same optimizer.\n",
    "for param in spec_resnet.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "train(spec_resnet, optimizer, nn.CrossEntropyLoss(), esc50_train_loader, esc50_val_loader, epochs=5, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ESC50WithPitchChange(Dataset):\n",
    "    \"\"\"ESC-50 clips loaded through a SoX effects chain that applies a pitch effect.\n",
    "\n",
    "    Items are (audio tensor, int class id), matching what ESC50 returns.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,path):\n",
    "        # Get directory listing from path\n",
    "        files = Path(path).glob('*.wav')\n",
    "        # Iterate through the listing and create a list of tuples (filename, label).\n",
    "        # File names are stored as str, as ESC50 does, rather than as Path objects.\n",
    "        self.items = [(str(f),f.name.split(\"-\")[-1].replace(\".wav\",\"\")) for f in files]\n",
    "        self.length = len(self.items)\n",
    "        # A single effects chain is built once and re-pointed at each file on access.\n",
    "        self.E = torchaudio.sox_effects.SoxEffectsChain()\n",
    "        self.E.append_effect_to_chain(\"pitch\", [0.5])\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        filename, label = self.items[index]\n",
    "        self.E.set_input_file(filename)\n",
    "        audio_tensor, sample_rate = self.E.sox_build_flow_effects()\n",
    "        # Return an int so default batch collation produces an integer target tensor.\n",
    "        return audio_tensor, int(label)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FrequencyMask(object):\n",
    "    \"\"\"Blank out a random horizontal band (a run of rows) of an image tensor.\n",
    "\n",
    "      Example:\n",
    "        >>> transforms.Compose([\n",
    "        >>>     transforms.ToTensor(),\n",
    "        >>>     FrequencyMask(max_width=10, use_mean=False),\n",
    "        >>> ])\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_width, use_mean=True):\n",
    "        # max_width: largest number of rows masked; use_mean: fill with the tensor mean instead of 0.\n",
    "        self.max_width = max_width\n",
    "        self.use_mean = use_mean\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            tensor (Tensor): Tensor image of \n",
    "            size (C, H, W) where the frequency \n",
    "            mask is to be applied.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: The same tensor, modified in place, with the Frequency Mask.\n",
    "        \"\"\"\n",
    "        # The band is cut along dim 1 (H), so the start row must come from that dimension.\n",
    "        start = random.randrange(0, tensor.shape[1])\n",
    "        # Band height is 1..max_width inclusive, which also makes max_width=1 valid.\n",
    "        end = start + random.randrange(1, self.max_width + 1)\n",
    "        fill = tensor.mean() if self.use_mean else 0\n",
    "        # Slicing clips at the edge, so a band starting near the bottom is just shorter.\n",
    "        tensor[:, start:end, :] = fill\n",
    "        return tensor\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"{}(max_width={}, use_mean={})\".format(\n",
    "            self.__class__.__name__, self.max_width, self.use_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual check: apply a zero-filled frequency mask to random noise and show the result as an image.\n",
    "transforms.Compose([FrequencyMask(max_width=10, use_mean=False),\n",
    "transforms.ToPILImage()])(torch.rand(3,250,200))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TimeMask(object):\n",
    "    \"\"\"Blank out a random vertical band (a run of columns) of an image tensor.\n",
    "\n",
    "      Example:\n",
    "        >>> transforms.Compose([\n",
    "        >>>     transforms.ToTensor(),\n",
    "        >>>     TimeMask(max_width=10, use_mean=False),\n",
    "        >>> ])\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_width, use_mean=True):\n",
    "        # max_width: largest number of columns masked; use_mean: fill with the tensor mean instead of 0.\n",
    "        self.max_width = max_width\n",
    "        self.use_mean = use_mean\n",
    "\n",
    "    def __call__(self, tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            tensor (Tensor): Tensor image of \n",
    "            size (C, H, W) where the time mask \n",
    "            is to be applied.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: The same tensor, modified in place, with the Time Mask.\n",
    "        \"\"\"\n",
    "        # The band is cut along dim 2 (W), so the start column must come from that dimension.\n",
    "        start = random.randrange(0, tensor.shape[2])\n",
    "        # Band width is 1..max_width inclusive, so a mask is never a zero-width no-op.\n",
    "        end = start + random.randrange(1, self.max_width + 1)\n",
    "        fill = tensor.mean() if self.use_mean else 0\n",
    "        # Slicing clips at the edge, so a band starting near the right side is just narrower.\n",
    "        tensor[:, :, start:end] = fill\n",
    "        return tensor\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"{}(max_width={}, use_mean={})\".format(\n",
    "            self.__class__.__name__, self.max_width, self.use_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual check: apply a zero-filled time mask to random noise and show the result as an image.\n",
    "transforms.Compose([TimeMask(max_width=10, use_mean=False),\n",
    "transforms.ToPILImage()])(torch.rand(3,250,200))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrecomputedTransformESC50(Dataset):\n",
    "    \"\"\"Spectrogram PNG dataset with random frequency and time masking applied on load.\n",
    "\n",
    "    Files are located and labelled the same way as in PrecomputedESC50:\n",
    "    ``<directory name><dpi>*.wav.png`` inside `path`, with the integer after the\n",
    "    final dash as the label.\n",
    "\n",
    "    Args:\n",
    "        path: directory holding the PNG files (str or Path).\n",
    "        max_freqmask_width: max_width passed to FrequencyMask.\n",
    "        max_timemask_width: max_width passed to TimeMask.\n",
    "        use_mean: fill masked regions with the image mean instead of zero.\n",
    "        dpi: dpi value embedded in the file names to load.\n",
    "    \"\"\"\n",
    "    def __init__(self, path, max_freqmask_width, max_timemask_width, use_mean=True, dpi=50):\n",
    "        path = Path(path)\n",
    "        # Same file-name pattern as PrecomputedESC50, so both datasets see the same images.\n",
    "        files = path.glob('{}{}*.wav.png'.format(path.name, dpi))\n",
    "        # Labels are stored as int so batches collate into an integer target tensor.\n",
    "        self.items = [(f,int(f.name.split(\"-\")[-1].replace(\".wav.png\",\"\"))) for f in files]\n",
    "        self.length = len(self.items)\n",
    "        self.max_freqmask_width = max_freqmask_width\n",
    "        self.max_timemask_width = max_timemask_width\n",
    "        self.use_mean = use_mean\n",
    "        # Each mask is applied independently with probability 0.5.\n",
    "        self.img_transforms = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.RandomApply([FrequencyMask(self.max_freqmask_width, self.use_mean)], p=0.5),\n",
    "            transforms.RandomApply([TimeMask(self.max_timemask_width, self.use_mean)], p=0.5)\n",
    "        ])\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        filename, label = self.items[index]\n",
    "        # Convert to three channels, as PrecomputedESC50 does, so both datasets yield the same shape.\n",
    "        img = Image.open(filename).convert('RGB')\n",
    "        return (self.img_transforms(img), label)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}